import torch

# Demonstrate how Tensor.sum(dim=...) collapses each axis of a 3-D tensor.
# torch.tensor (lowercase) is the recommended factory function; the legacy
# torch.Tensor constructor always produced float32, so that dtype is pinned
# explicitly here to keep the printed results unchanged.
a = torch.tensor(
    [
        [
            [1, 2, 3],
            [4, 5, 6],
        ],
        [
            [7, 8, 9],
            [10, 11, 12],
        ],
    ],
    dtype=torch.float32,
)
print(a.shape)  # torch.Size([2, 2, 3])

# Summing over a dimension removes that dimension from the result:
#   dim=0 -> shape (2, 3): [[8, 10, 12], [14, 16, 18]]
#   dim=1 -> shape (2, 3): [[5, 7, 9], [17, 19, 21]]
#   dim=2 -> shape (2, 2): [[6, 15], [24, 33]]
for dim in range(a.dim()):
    print(a.sum(dim=dim))